﻿using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using Tensorflow;
using Tensorflow.Keras.Utils;
using Tensorflow.NumPy;
using static Tensorflow.Binding;

namespace TensorFlowNET.Examples;

/// <summary>
/// Implement Word2Vec algorithm to compute vector representations of words.
/// https://github.com/aymericdamien/TensorFlow-Examples/blob/master/examples/2_BasicModels/word2vec.py
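/// Loads a pre-built skip-gram graph (word2vec.meta) and trains it with NCE loss.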
/// </summary>
public class Word2Vec : SciSharpExample, IExample
{
    // Training Parameters
    int batch_size = 128;
    int num_steps = 30000; // full run: 3000000
    int display_step = 1000; // full run: 10000
    int eval_step = 5000; // full run: 200000

    // Evaluation Parameters
    string[] eval_words = new string[] { "five", "of", "going", "hardware", "american", "britain" };

    // Dataset
    string[] text_words; // raw tokens from the text8 corpus
    List<WordId> word2id; // vocabulary, ordered by descending frequency
    int[] data; // the corpus encoded as word ids (0 = 'UNK')

    // Word2Vec Parameters
    int min_occurrence = 10; // Remove all words that do not appear at least n times
    int skip_window = 3; // How many words to consider left and right
    int num_skips = 2; // How many times to reuse an input to generate a label

    int data_index = 0; // current read position within the encoded corpus
    int top_k = 8; // number of nearest neighbors
    float average_loss = 0;

    public ExampleConfig InitConfig()
        => Config = new ExampleConfig
        {
            Name = "Word2Vec",
            Enabled = true,
            IsImportingGraph = true
        };

    public bool Run()
    {
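        // Importing a TF1-style meta graph requires graph mode, so turn off eager execution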
        tf.compat.v1.disable_eager_execution();

        PrepareData();

        var graph = tf.Graph().as_default();

        tf.train.import_meta_graph($"graph{Path.DirectorySeparatorChar}word2vec.meta");

        // Input data
        Tensor X = graph.OperationByName("Placeholder");
        // Input label
        Tensor Y = graph.OperationByName("Placeholder_1");

        // Compute the average NCE loss for the batch
        Tensor loss_op = graph.OperationByName("Mean");
        // Define the optimizer
        var train_op = graph.OperationByName("GradientDescent");
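        // Cosine similarity between the evaluation-word embeddings and all word embeddings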
        Tensor cosine_sim_op = graph.OperationByName("MatMul_1");

        // Initialize the variables (i.e. assign their default value)
        var init = tf.global_variables_initializer();

        var sess = tf.Session(graph);
        // Run the initializer
        sess.run(init);

        // Look up the vocabulary id of each evaluation word (0 for 'UNK' when absent)
        var x_test = (from word in eval_words
                      join id in word2id on word equals id.Word into wi
                      from wi2 in wi.DefaultIfEmpty()
                      select wi2 == null ? 0 : wi2.Id).ToArray();

        foreach (var step in range(1, num_steps + 1))
        {
            // Get a new batch of data
            var (batch_x, batch_y) = next_batch(batch_size, num_skips, skip_window);

            (_, float loss) = sess.run((train_op, loss_op), (X, batch_x), (Y, batch_y));
            average_loss += loss;

            if (step % display_step == 0 || step == 1)
            {
                if (step > 1)
                    average_loss /= display_step;

                print($"Step {step}, Average Loss= {average_loss.ToString("F4")}");
                average_loss = 0;
            }

            // Evaluation
            if (step % eval_step == 0 || step == 1)
            {
                print("Evaluation...");
                var sim = sess.run(cosine_sim_op, (X, x_test));
                foreach (var i in range(len(eval_words)))
                {
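                    // Negating the similarities makes argsort return indices in descending
                    // order; skip index 0 (the word itself) and keep the top_k neighbors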
                    var nearest = np.argsort(0f - sim[i])
                        .ToArray<int>()
                        .Skip(1)
                        .Take(top_k)
                        .ToArray();
                    string log_str = $"\"{eval_words[i]}\" nearest neighbors:";
                    foreach (var k in range(top_k))
                        log_str = $"{log_str} {word2id.First(x => x.Id == nearest[k]).Word},";
                    print(log_str);
                }
            }
        }

        return average_loss < 100;
    }

    // Generate training batch for the skip-gram model
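    // For each center word, num_skips context words are drawn at random from the
    // window of skip_window words on either side. E.g. with skip_window = 1 and
    // num_skips = 2, the text "the quick brown" centered on "quick" produces the
    // training pairs (quick -> the) and (quick -> brown).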
    private (NDArray, NDArray) next_batch(int batch_size, int num_skips, int skip_window)
    {
        var batch = np.ndarray(new Shape(batch_size), dtype: np.int32);
        var labels = np.ndarray((batch_size, 1), dtype: np.int32);
        // get window size (words left and right + current one)
        int span = 2 * skip_window + 1;
        var buffer = new Queue<int>(span);
        if (data_index + span > data.Length)
            data_index = 0;
        data.Skip(data_index).Take(span).ToList().ForEach(x => buffer.Enqueue(x));
        data_index += span;

        var rand = new Random(); // one RNG reused across the whole batch
        foreach (var i in range(batch_size / num_skips))
        {
            // Candidate context positions: every slot in the window except the center word
            var context_words = range(span).Where(x => x != skip_window).ToList();
            // Sample num_skips context positions without replacement
            var words_to_use = context_words.OrderBy(x => rand.Next()).Take(num_skips).ToArray();
            foreach (var (j, context_word) in enumerate(words_to_use))
            {
                batch[i * num_skips + j] = buffer.ElementAt(skip_window);
                labels[i * num_skips + j, 0] = buffer.ElementAt(context_word);
            }

            if (data_index == len(data))
            {
                // Reached the end of the data: refill the window buffer from the start
                buffer.Clear();
                data.Take(span).ToList().ForEach(x => buffer.Enqueue(x));
                data_index = span;
            }
            else
            {
                buffer.Dequeue();
                buffer.Enqueue(data[data_index]);
                data_index += 1;
            }
        }

        // Backtrack a little to avoid skipping words at the end of a batch
        data_index = (data_index + len(data) - span) % len(data);

        return (batch, labels);
    }

    public override void PrepareData()
    {
        // Download graph meta
        var url = "https://github.com/SciSharp/TensorFlow.NET/raw/master/graph/word2vec.meta";
        Web.Download(url, "graph", "word2vec.meta");

        // Download a small chunk of Wikipedia articles collection
        url = "https://raw.githubusercontent.com/SciSharp/TensorFlow.NET/master/data/text8.zip";
        Web.Download(url, "word2vec", "text8.zip");
        // Unzip the dataset file; the text has already been preprocessed
        Compress.UnZip($"word2vec{Path.DirectorySeparatorChar}text8.zip", "word2vec");

        int wordId = 0;
        text_words = File.ReadAllText($"word2vec{Path.DirectorySeparatorChar}text8").Trim().ToLower().Split();
        // Build the dictionary and replace rare words with UNK token
        word2id = text_words.GroupBy(x => x)
            .Select(x => new WordId
            {
                Word = x.Key,
                Occurrence = x.Count()
            })
            .Where(x => x.Occurrence >= min_occurrence) // Remove samples with less than 'min_occurrence' occurrences
            .OrderByDescending(x => x.Occurrence) // Retrieve the most common words
            .Select(x => new WordId
            {
                Word = x.Word,
                Id = ++wordId, // Assign ids starting from 1; 0 is reserved for 'UNK'
                Occurrence = x.Occurrence
            })
            .ToList();

        // Retrieve a word id, or assign it index 0 ('UNK') if not in dictionary
        data = (from word in text_words
                join id in word2id on word equals id.Word into wi
                from wi2 in wi.DefaultIfEmpty()
                select wi2 == null ? 0 : wi2.Id).ToArray();

        word2id.Insert(0, new WordId { Word = "UNK", Id = 0, Occurrence = data.Count(x => x == 0) });

        print($"Words count: {text_words.Length}");
        print($"Unique words: {text_words.Distinct().Count()}");
        print($"Vocabulary size: {word2id.Count}");
        print($"Most common words: {string.Join(", ", word2id.Take(10))}");
    }

    private class WordId
    {
        public string Word { get; set; }
        public int Id { get; set; }
        public int Occurrence { get; set; }

        public override string ToString()
        {
            return $"{Word} {Id} {Occurrence}";
        }
    }
}
